[NPU] Add NPU Fused MoE kernel#1183
Merged
Tcc0403 merged 6 commits intolinkedin:mainfrom Apr 24, 2026
Merged
Conversation
Contributor
Author
Contributor
Author
|
@Tcc0403 This PR is ready for review. |
Tcc0403
reviewed
Apr 21, 2026
Collaborator
Tcc0403
left a comment
There was a problem hiding this comment.
LGTM, just a tiny issue
Comment on lines
-234
to
+241
| torch.cuda.synchronize() | ||
| if device == "cuda": | ||
| torch.cuda.synchronize() | ||
| elif device == "npu": | ||
| torch.npu.synchronize() | ||
|
|
Collaborator
There was a problem hiding this comment.
Great catch, we also have CPU support. Could you add it?
zheliuyu
commented
Apr 22, 2026
Comment on lines
+160
to
+165
| if device == "cuda": | ||
| torch.cuda.synchronize() | ||
| elif device == "npu": | ||
| torch.npu.synchronize() | ||
| else: | ||
| torch.cpu.synchronize() |
Contributor
Author
There was a problem hiding this comment.
@Tcc0403 Thanks for the suggestion. torch provides a cpu equivalent, so I've added it here.
Collaborator
There was a problem hiding this comment.
Sorry typo, meant to be xpu not cpu 😅
Contributor
Author
There was a problem hiding this comment.
Got it. Added torch.xpu.synchronize, please take another look.
zheliuyu
commented
Apr 23, 2026
Comment on lines
+28
to
+47
| def compute_routing_metadata(topk_indices: torch.Tensor, E: int, block_m_token: int = BLOCK_M_TOKEN): | ||
| """Compute token→expert routing permutation metadata via 3 Triton kernels. | ||
|
|
||
| Also computes GPU tile metadata (tile_row_start, tile_expert) inside | ||
| Kernel 3 — no CPU loop, one .item() sync for num_m_tiles allocation. | ||
|
|
||
| Args: | ||
| topk_indices: (T, K) int32 — pre-computed top-k expert indices per token | ||
| E: number of experts | ||
| block_m_token: BLOCK_M for token-dimension tiling (default BLOCK_M_TOKEN) | ||
|
|
||
| Returns: | ||
| expert_token_count: (E,) int32 | ||
| expert_start_idx: (E+1,) int32 | ||
| x_gather_idx: (TK,) int32 | ||
| s_scatter_idx: (TK,) int32 | ||
| s_reverse_scatter_idx: (TK,) int32 | ||
| tile_row_start: (num_m_tiles,) int32 — absolute row_start per M-tile | ||
| tile_expert: (num_m_tiles,) int32 — expert index per M-tile | ||
| """ |
Contributor
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.






Motivation
This pr ports
fused_moe.pyandfused_moe_kernels.pyto an NPU-affine implementation while preserving the original math. The computational definition is unchanged: forward remainsW1 (gate/up) -> SwiGLU -> W2 -> token-weighted gather, and backward still followsdA' = dO @ W2^Tto produced_pre_act / dS / dW2 / dX / dW1.The main changes are execution-strategy optimizations for NPU.
Note: Use the Skill
For this fused_moe kernel migration, we followed the skill document from #1197.
Testing Done
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergence🤖 Generated with: cursor.